"""tags table [6119cd9b93c2].

Revision ID: 6119cd9b93c2
Revises: 0.46.1
Create Date: 2023-11-08 11:52:16.440417

"""

import json
import random
from collections import defaultdict
from datetime import datetime
from typing import Set
from uuid import uuid4

import sqlalchemy as sa
import sqlmodel
from alembic import op

# revision identifiers, used by Alembic.
# ``revision`` is this migration's own ID; ``down_revision`` is the ID of the
# migration it directly follows ("0.46.1", matching "Revises" in the module
# docstring). No branch labels or cross-revision dependencies are declared.
revision = "6119cd9b93c2"
down_revision = "0.46.1"
branch_labels = None
depends_on = None


def _parse_model_tags(raw_tags: str) -> Set[str]:
    """Decode one legacy ``model.tags`` value into a set of tag names.

    Args:
        raw_tags: The raw content of the ``model.tags`` column, expected to
            be a JSON-encoded list of tag names.

    Returns:
        The unique tag names as strings. An empty set is returned when the
        value is not valid JSON or does not decode to a list, so such rows
        are skipped instead of aborting the migration.
    """
    try:
        decoded = json.loads(raw_tags)
    except (json.JSONDecodeError, TypeError):
        return set()
    if not isinstance(decoded, list):
        # e.g. ``null`` (which ``set()`` would reject) or a bare string
        # (which ``set()`` would split into single characters)
        return set()
    # The new ``tag.name`` column holds strings, so normalise non-string
    # entries (e.g. ``1``) up front to keep names and ID lookups consistent.
    return {str(tag) for tag in decoded if tag is not None}


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Creates the ``tag`` and ``tag_resource`` tables, moves the JSON-encoded
    tags of every ``model`` row into them (one ``tag`` row per distinct name,
    one ``tag_resource`` row per model/tag pair) and finally drops the legacy
    ``model.tags`` column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    from zenml.enums import ColorVariants, TaggableResourceTypes

    bind = op.get_bind()
    session = sqlmodel.Session(bind=bind)

    op.create_table(
        "tag",
        sa.Column("color", sa.VARCHAR(length=255), nullable=False),
        sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("created", sa.DateTime(), nullable=False),
        sa.Column("updated", sa.DateTime(), nullable=False),
        sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )

    # fetch current tags in Model table and de-json them; models without any
    # usable tag are skipped
    model_tags = session.execute(
        sa.text(
            """
            SELECT id, tags
            FROM model
            WHERE tags IS NOT NULL
            """
        )
    )
    model_tags_prepared = []
    for model_id_, raw_tags in model_tags:
        tags = _parse_model_tags(raw_tags)
        if tags:
            model_tags_prepared.append((model_id_, tags))
    unique_tags: Set[str] = set().union(
        *(tags for _, tags in model_tags_prepared)
    )

    # Pre-generate the tag IDs so the relation rows below can reference them
    # without reading the freshly inserted tags back from the database.
    tag_ids = {tag: uuid4().hex for tag in sorted(unique_tags)}
    now = str(datetime.now())

    # Values are passed as bound parameters instead of being interpolated
    # into the SQL, so tag names containing quotes or ``:name``-like
    # sequences neither break the statement nor become bind placeholders.
    if tag_ids:
        colors = list(ColorVariants)
        session.execute(
            sa.text(
                """
                INSERT INTO tag (id, name, color, created, updated)
                VALUES (:id, :name, :color, :created, :updated)
                """
            ),
            params=[
                {
                    "id": tag_id,
                    "name": tag,
                    "color": random.choice(colors).value,
                    "created": now,
                    "updated": now,
                }
                for tag, tag_id in tag_ids.items()
            ],
        )

    op.create_table(
        "tag_resource",
        sa.Column("tag_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("resource_type", sa.VARCHAR(length=255), nullable=False),
        sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("created", sa.DateTime(), nullable=False),
        sa.Column("updated", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(
            ["tag_id"],
            ["tag.id"],
            name="fk_tag_resource_tag_id_tag",
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
    )

    # link every model to each of its tags in the new relation table
    tag_resource_rows = [
        {
            "id": uuid4().hex,
            "tag_id": tag_ids[tag],
            "resource_id": model_id_,
            "resource_type": TaggableResourceTypes.MODEL.value,
            "created": now,
            "updated": now,
        }
        for model_id_, tags in model_tags_prepared
        for tag in tags
    ]
    if tag_resource_rows:
        session.execute(
            sa.text(
                """
                INSERT INTO tag_resource
                (id, tag_id, resource_id, resource_type, created, updated)
                VALUES
                (:id, :tag_id, :resource_id, :resource_type, :created, :updated)
                """
            ),
            params=tag_resource_rows,
        )

    session.commit()

    with op.batch_alter_table("model", schema=None) as batch_op:
        batch_op.drop_column("tags")

    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Re-adds the legacy ``model.tags`` column, fills it with a JSON list of
    the tag names linked to each resource through ``tag_resource``, and then
    drops the ``tag_resource`` and ``tag`` tables.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    session = sqlmodel.Session(bind=op.get_bind())

    # The column must exist again before any data can be written into it.
    with op.batch_alter_table("model", schema=None) as batch_op:
        batch_op.add_column(sa.Column("tags", sa.TEXT(), nullable=True))

    # Collect the tag names attached to every resource id.
    pairs = session.execute(
        sa.text(
            """
            SELECT resource_id,t.name
            FROM tag_resource as tr
            JOIN tag as t ON tr.tag_id = t.id
            """
        )
    )
    names_by_resource = defaultdict(set)
    for resource_id, tag_name in pairs:
        names_by_resource[resource_id].add(tag_name)

    set_tags_stmt = sa.text(
        """
        UPDATE model
        SET tags = :tags
        WHERE id = :model_id
        """
    )
    # Write each group back onto its model row as a JSON list of strings.
    for resource_id, names in names_by_resource.items():
        payload = json.dumps(list(names))
        session.execute(
            set_tags_stmt,
            params={"tags": payload, "model_id": resource_id},
        )
    session.commit()

    op.drop_table("tag_resource")
    op.drop_table("tag")
    # ### end Alembic commands ###
